iT邦幫忙

2025 iThome 鐵人賽

DAY 16
0

昨天我們已經完成了 LLaMA2 的 Attention 模組,今天我們要繼續完成剩下的部分

LLaMA2 MLP 模塊

  MLP(Multi-Layer Perceptron)是 Transformer 每個 Block 中 FNN 的實作,在 LLaMA2 的設計中,MLP 採用了 SwiGLU 結構,這是相較於傳統 FNN 的一大改進,一般的 Transformer 只會透過單一路徑的激勵函數來進行非線性轉換,而 LLaMA2 則同時引入兩個線性變換,並透過 SiLU 激活與逐元素相乘的方式,讓網路在相同參數量下具有更強的表達能力,這樣的設計能夠更好地捕捉語言中的高階語義關係,也提升了深層網路的訓練穩定性。

class MLP(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = 4 * dim
            hidden_dim = int(2 * hidden_dim / 3)
            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))

LLaMA2 Decoder Layer

Decoder Layer 把 Attention + MLP + Norm 組合起來

class DecoderLayer(nn.Module):
    def __init__(self, layer_id: int, args: ModelConfig):
        super().__init__()
        self.attention = Attention(args)
        self.feed_forward = MLP(args.dim, args.hidden_dim, args.multiple_of, args.dropout)
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x, freqs_cos, freqs_sin):
        h = x + self.attention(self.attention_norm(x), freqs_cos, freqs_sin)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out

LLaMA2 模型

  當我們實作完成 DecoderLayer 之後,就能夠進一步堆疊多層,構建出一個完整的 LLaMA2 模型,整體架構的流程是:
  首先輸入的 token 會先經過詞嵌入層(Embedding)轉換為向量表示,接著依序通過多層 Decoder Layer,每一層都包含 Attention、FNN、Norm 等子模組,透過層層疊加不斷提取更高階的語義資訊,最後最上層的輸出會通過一個線性層將特徵轉換回詞彙空間,用於預測下一個詞的機率分佈。

class Transformer(PreTrainedModel):
    config_class = ModelConfig

    def __init__(self, args: ModelConfig = None):
        super().__init__(args)
        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers

        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        self.dropout = nn.Dropout(args.dropout)
        self.layers = nn.ModuleList([DecoderLayer(i, args) for i in range(args.n_layers)])
        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

        # 權重共享:Embedding <-> Output
        self.tok_embeddings.weight = self.output.weight

        # 預先計算 RoPE 頻率
        freqs_cos, freqs_sin = precompute_freqs_cis(self.args.dim // self.args.n_heads, self.args.max_seq_len)
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

    def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None):
        bsz, seqlen = tokens.shape
        h = self.dropout(self.tok_embeddings(tokens))
        freqs_cos, freqs_sin = self.freqs_cos[:seqlen], self.freqs_sin[:seqlen]

        for layer in self.layers:
            h = layer(h, freqs_cos, freqs_sin)

        h = self.norm(h)
        logits = self.output(h)

        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=0)
            return logits, loss
        return logits

    @torch.inference_mode()
    def generate(self, idx, max_new_tokens=50, temperature=1.0, top_k=None):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.args.max_seq_len:]
            logits = self(idx_cond)
            logits = logits[:, -1, :] / temperature

            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

  這種模組化的設計讓模型具備良好的可擴展性與可微調性,只要調整層數和維度大小,就能快速適應不同規模的應用場景。

參考連結:
https://datawhalechina.github.io/happy-llm/#/


上一篇
[Day15] 實作一個 LLaMA2 模型 (二)
下一篇
[Day17] 訓練我們的 Tokenizer!
系列文
從上下文工程到 Agent:30 天生成式 AI 與 LLM 學習紀錄21
圖片
  熱門推薦
圖片
{{ item.channelVendor }} | {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言